import numpy as np
import pickle as pkl
import os, re
from collections import defaultdict

from scipy.special import softmax

import matplotlib.pyplot as plt
import matplotlib.patches as patches

## custom color palette
# Each colour is an (R, G, B) tuple with 0-255 channel values scaled to [0, 1].
lblue = (40/255,103/255,178/255)  # blue
cred  = (177/255, 4/255, 14/255)  # red
nyllw = (245/255,213/255,71/255)# Naples yellow

import torch
import torchvision
import torchvision.transforms as transforms
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

def proj(w, j, eps):
    """Move ``w`` so that entry ``j`` leads every other entry by ``eps/sqrt(2)``.

    The other entries are sorted in decreasing order and, for each possible
    number of entries to be capped, a candidate cap level is computed from the
    running mean of ``w[j]`` and the largest remaining entries. The first
    level that exceeds the next-largest remaining entry is used: every entry
    above it is lowered to that level and entry ``j`` is set to the level plus
    ``eps/sqrt(2)``. When ``w[j] - eps/sqrt(2)`` already exceeds all other
    entries, ``w`` is returned unchanged.

    NOTE(review): presumably this is the Euclidean projection onto the set
    where coordinate ``j`` wins by that margin -- confirm against the
    accompanying derivation.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    j : index of the coordinate that should lead.
    eps : margin parameter.

    Returns
    -------
    tuple
        ``(w_proj, dist)`` where ``w_proj`` is a modified copy of ``w`` and
        ``dist`` is the Euclidean norm of ``w - w_proj``.
    """
    dim = len(w)
    margin = eps / np.sqrt(2)
    counts = np.arange(1, dim + 1)

    # Competitors of coordinate j, largest first.
    rivals = np.sort(np.delete(w, j))[::-1]

    # Candidate cap level when the top 0, 1, ..., dim-1 rivals are capped.
    running_means = (w[j] + np.hstack([0, np.cumsum(rivals)])) / counts
    levels = running_means - margin / counts

    # The trailing -inf guarantees that at least one position qualifies.
    qualifies = np.hstack([rivals, -float('Inf')]) < levels
    first = np.min(np.where(qualifies))

    projected = w.copy()
    if first > 0:
        level = levels[first]
        projected[projected > level] = level
        projected[j] = level + margin
    return (projected, np.linalg.norm(w - projected))

def kstar(w, eps):
    """Count the leading coordinates of ``w`` whose ``proj`` distance is below ``eps``.

    Coordinates are visited in decreasing order of ``w``. The count stops at
    the first coordinate whose distance returned by ``proj`` is not strictly
    smaller than ``eps``.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    eps : distance threshold, also passed through to ``proj``.

    Returns
    -------
    int
        Number of coordinates accepted before the first rejection.
    """
    accepted = 0
    for idx in np.argsort(-w):
        _, distance = proj(w, idx, eps)
        if not (distance < eps):
            break
        accepted += 1
    return accepted

def s_argmax(w):
    """Return the singleton set holding the index of the largest entry of ``w``."""
    best = np.argmax(w)
    return {best}

def s(w, eps):
    """Return the indices of the ``kstar(w, eps)`` largest entries of ``w`` as a set.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    eps : threshold forwarded to ``kstar``.
    """
    n_selected = kstar(w, eps)
    ranking = np.argsort(w)[::-1]
    return set(ranking[:n_selected])

def s_topk(w, k=2):
    """Return the set of indices of the ``k`` largest entries of ``w``."""
    descending = np.argsort(w)[::-1]
    chosen = descending[:k]
    return set(chosen)

def s_thresh(w, t=.8):
    """Return the smallest top-ranked set whose cumulative score reaches ``t``.

    Entries are visited in decreasing order of ``w`` and added until the
    running sum is at least ``t``.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    t : cumulative-score threshold.

    Returns
    -------
    set
        Indices of the selected entries. If the running sum never reaches
        ``t`` (empty input, scores summing to less than ``t``, or floating
        point round-off when ``t`` is close to the total), every index is
        returned; previously that case raised a ``ValueError`` from taking
        the minimum of an empty array.
    """
    order = np.argsort(w)[::-1]
    reached = np.where(np.cumsum(w[order]) >= t)[0]
    # np.where yields ascending positions, so the first one is the smallest.
    k = reached[0] + 1 if reached.size > 0 else len(order)
    return set(order[:k])

def s_ndd(w, beta=1):
    """Grow the top-ranked set until the quantity ``Delta`` first increases.

    With entries sorted in decreasing order and ``P_k`` the sum of the ``k``
    largest, ``Delta_k = 1 - (beta**2 + 1) / (beta**2 + k) * P_k``. The
    returned set holds the top ``k`` indices for the first ``k`` at which
    ``Delta_{k+1} > Delta_k``; if ``Delta`` never increases, all indices are
    returned.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    beta : weighting parameter entering ``Delta`` as ``beta**2``.
    """
    n_classes = len(w)
    ranking = np.argsort(w)[::-1]
    mass = np.cumsum(w[ranking])
    sizes = np.arange(1, n_classes + 1)
    Delta = 1 - (beta**2 + 1) / (beta**2 + sizes) * mass

    # Positions k (0-based) where the next value is strictly larger.
    rises = np.where(Delta[:-1] < Delta[1:])[0]
    if rises.size == 0:
        return set(ranking)
    return set(ranking[:rises.min() + 1])
    
def g(s, delta=1.6, gamma=.6):
    """Evaluate ``delta/s - gamma/s**2``.

    Works elementwise when ``s`` is a numpy array. With the defaults the
    value at ``s = 2`` is 0.65.
    """
    first_order = delta / s
    second_order = gamma / s**2
    return first_order - second_order

def s_u(w, gvec):
    """Grow the top-ranked set until ``gvec[k] * P_k`` first decreases.

    ``P_k`` is the sum of the ``k`` largest entries of ``w``. The returned
    set holds the top ``k`` indices for the first ``k`` whose weighted value
    is strictly larger than the next one; if the weighted values never
    decrease, all indices are returned.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    gvec : array of per-size weights, multiplied elementwise with the
        cumulative sums (so it must be broadcastable against them).
    """
    ranking = np.argsort(w)[::-1]
    weighted = gvec * np.cumsum(w[ranking])

    # Positions k (0-based) where the next value is strictly smaller.
    drops = np.where(weighted[:-1] > weighted[1:])[0]
    if drops.size == 0:
        return set(ranking)
    return set(ranking[:drops.min() + 1])
    
def fun(w, method):
    """Return the set-valued prediction for score vector ``w`` under ``method``.

    Parameters
    ----------
    w : numpy array of scores (1-D).
    method : one of ``'inflate'``, ``'argmax'``, ``'topk'``, ``'threshold'``,
        ``'ndd'``, ``'SVBOP-u65'``, ``'SVBOP-u80'`` or ``'prec'``.

    Returns
    -------
    set
        Indices selected by the chosen rule.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above. Previously an unknown
        name fell through every branch and returned ``None``.
    """
    if method == 'inflate':
        # EPS is a module-level constant, looked up when the call is made.
        return s(w, EPS)
    if method == 'argmax':
        return s_argmax(w)
    if method == 'topk':
        return s_topk(w)
    if method == 'threshold':
        return s_thresh(w)
    if method == 'ndd':
        return s_ndd(w)
    if method in ('SVBOP-u65', 'SVBOP-u80', 'prec'):
        # These three rules differ only in the per-size weight vector.
        seq = np.arange(1, len(w) + 1)
        if method == 'SVBOP-u65':
            gvec = g(seq)
        elif method == 'SVBOP-u80':
            gvec = g(seq, 2.2, 1.2)
        else:
            gvec = 1 / seq
        return s_u(w, gvec)
    raise ValueError(f"unknown method: {method!r}")
    
def u50(y, S):
    """Return ``1/len(S)`` if ``y`` is in ``S`` and ``0.0`` otherwise."""
    hit = (y in S) * 1.
    return hit / len(S)

def u65(y, S):
    """Return ``1.6/k - 0.6/k**2`` with ``k = len(S)`` if ``y`` is in ``S``, else ``0.0``."""
    size = len(S)
    discount = 1.6 / size - .6 / size**2
    covered = y in S
    return covered * discount

def u80(y, S):
    """Return ``2.2/k - 1.2/k**2`` with ``k = len(S)`` if ``y`` is in ``S``, else ``0.0``."""
    size = len(S)
    discount = 2.2 / size - 1.2 / size**2
    covered = y in S
    return covered * discount

# Weights stored by the run in 'full/1'; the array is unpacked as (N, K) below.
fp = 'full/1'
with open(os.path.join(fp, 'weights.pkl'), 'rb') as f:
    w_full = pkl.load(f)

n = 60000        # size of the index range the stored 'indices.pkl' files refer to
m = int(n/2)     # half of n
N, K = w_full.shape
B = 1000

# Run directories under 'loo': names without '_' or '.' that contain a
# weights file. (Presumably one leave-one-out fit per directory -- confirm.)
fp = 'loo'
RUN_IDs = [
    fn for fn in os.listdir(fp)
    if not ('_' in fn or '.' in fn)
    and os.path.exists(os.path.join(fp, fn, 'weights.pkl'))
]
N_loo = len(RUN_IDs)

# For each run, record one index from range(n) that is absent from its
# stored index list.
ns = []
for ID in RUN_IDs:
    with open(os.path.join(fp, ID, 'indices.pkl'), 'rb') as f:
        kept = set(pkl.load(f))
    missing = set(range(n)) - kept
    ns.append(list(missing)[0])

# Stack the per-run weight arrays in the same order as RUN_IDs / ns.
w_loo = np.zeros((N_loo, N, K))
for b, ID in enumerate(RUN_IDs):
    with open(os.path.join(fp, ID, 'weights.pkl'), 'rb') as f:
        w_loo[b] = pkl.load(f)

from tqdm import tqdm
fp = '../Fashion MNIST experiment/subbagging_bigb/subbagging'

# Keep only run directories that contain both a weights file and an index file.
# NOTE(review): os.listdir returns entries in arbitrary order, so the bag
# numbering -- and therefore which bags the seeded draws below select -- may
# differ between machines. Sorting would make it deterministic but would also
# change which bags are drawn; confirm before changing.
RUN_IDs = []
for fn in tqdm(os.listdir(fp)):
    if (os.path.exists(os.path.join(fp, fn, 'weights.pkl')) 
        and os.path.exists(os.path.join(fp, fn, 'indices.pkl'))):
        RUN_IDs.append(fn)

B_tot = len(RUN_IDs)

# bags[b] holds the m stored indices of bag b; w_bags[b] its (N, K) weights.
bags = np.zeros((B_tot, m), dtype=int)
w_bags = np.zeros((B_tot, N, K))
# Wrap the list (not the enumerate object) so tqdm knows the total length.
for b, ID in enumerate(tqdm(RUN_IDs)):
    with open(os.path.join(fp, ID, 'indices.pkl'), 'rb') as f:
        bags[b] = pkl.load(f)
    with open(os.path.join(fp, ID, 'weights.pkl'), 'rb') as f:
        w_bags[b] = pkl.load(f)

# Average the weights of B bags drawn without replacement from all B_tot bags.
# The seed is set once here; the draws in the loop further down continue the
# same random stream, so their order must not change.
np.random.seed(1234567891)
inds = np.random.choice(B_tot, B, replace=False)
w_subb = np.zeros(w_bags.shape[1:])
for b in tqdm(inds):
    w_subb += w_bags[b]
w_subb /= B

# For every index i in ns, average B bags drawn only from the bags that do
# not contain i. np.random.choice raises if fewer than B such bags exist.
w_subb_= []
for i in ns:
    inds = np.where(np.all(bags!=i,1))[0]
    w_subb_.append(np.mean(w_bags[np.random.choice(inds, B, replace=False)], 0))
w_subb_ = np.array(w_subb_)

# Per-image preprocessing: convert to a tensor, then normalise with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Fashion-MNIST split selected by train=False. download=False, so the data
# files must already be present under ./data.
validation_set = torchvision.datasets.FashionMNIST('./data', 
                                                   train=False, 
                                                   transform=transform, 
                                                   download=False)
# Unshuffled loader with batch size N (the first dimension of w_full).
validation_loader = torch.utils.data.DataLoader(validation_set, 
                                                batch_size=N, 
                                                shuffle=False)
# Only the labels are kept. Y is overwritten on every iteration, so if the
# loader yields more than one batch, only the labels of the last batch survive.
# NOTE(review): presumably N equals the size of this split, giving a single
# batch -- confirm.
for data in validation_loader:
    _, Y = data
Y = np.array(Y)

METHODS = ['argmax', 'inflate', 'topk', 'threshold', 'ndd', 'SVBOP-u65', 'SVBOP-u80']
EPS = .05


def dd():
    """Return a fresh ``defaultdict(list)``; used as the factory for ``results``."""
    return defaultdict(list)


# results[metric][(method, wt_type)] -> list with one entry per evaluated point j.
results = defaultdict(dd)

for j in tqdm(range(N)):
    truth = set([Y[j]])
    for wt_type in ['base', 'subbagged']:
        use_base = wt_type == 'base'
        w = w_full[j] if use_base else w_subb[j]

        for method in METHODS:
            key = (method, wt_type)
            S = fun(w, method)

            # Compare the reference set with the set obtained from each of
            # the N_loo alternative weight vectors for the same point.
            ct = 0
            loo_prec = 0
            for i in range(N_loo):
                w_ = w_loo[i, j] if use_base else w_subb_[i, j]
                S_ = fun(w_, method)
                if not (S & S_):
                    # The two sets share no label.
                    ct += 1
                loo_prec += (truth == S_)

            # Fraction of alternative sets that are disjoint from S.
            results['stability'][key].append(ct/N_loo)
            # True when S is exactly the singleton holding the true label.
            results['precision'][key].append((truth == S))
            results['precision_loo'][key].append(loo_prec/N_loo)
            results['accuracy'][key].append((Y[j] in S))
            results['size'][key].append(len(S))
            results['u50'][key].append(u50(Y[j], S))
            results['u65'][key].append(u65(Y[j], S))
            results['u80'][key].append(u80(Y[j], S))

# For every (method, wt_type): restrict to the points where the method returns
# more than one label, and compare the argmax rule's accuracy on those points
# with the method's own accuracy there.
for wt_type in ['base', 'subbagged']:
    for method in METHODS:
        ct, cb, tot = 0, 0, 0
        for j in tqdm(range(N)):
            # check if this method abstains
            if results['size'][(method, wt_type)][j] > 1:
                tot += 1
                ct += results['accuracy'][(method, wt_type)][j]
                cb += results['accuracy'][('argmax', wt_type)][j]

        # ct > 0 implies tot > 0. Guarding on ct also covers the case where
        # the method abstains but never contains the true label on those
        # points, which previously raised ZeroDivisionError.
        if ct > 0:
            incorrectness_on_abstention_over_top1_on_abstention = (cb/tot)/(ct/tot) 
            # NOTE(review): the standard error divides by N although the
            # ratio is computed from the `tot` abstaining points only --
            # confirm this is intended.
            se = np.sqrt(incorrectness_on_abstention_over_top1_on_abstention*
                         (1-incorrectness_on_abstention_over_top1_on_abstention)
                         / N )
        else:
            incorrectness_on_abstention_over_top1_on_abstention = 'NA'
            se = 'NA'
        print(wt_type, method, 
              incorrectness_on_abstention_over_top1_on_abstention, se)
        results['sup-infl'][(method,wt_type)] = incorrectness_on_abstention_over_top1_on_abstention
        results['sup-infl-se'][(method,wt_type)] = se

# Display names (LaTeX / mathtext) for each selection rule.
formal_names = {
    'argmax' : 'argmax',
    'inflate' : 'argmax$^\\varepsilon$',
    'topk' : '$\\text{top-}2$',
    'threshold' : 'Thresholding $\\Gamma^*_{0.8}$',
    'ndd' : 'NDC for $F_1$',
    'SVBOP-u65' :  'SVBOP$_{u_{65}}$', 
    'SVBOP-u80' : 'SVBOP$_{u_{80}}$', 
    'prec' :  'SVBOP$_{u_{50}}$',
}

# Per-point metrics reported as "mean (standard error)" columns.
metrics = ['precision', 'size', 'u65', 'u80']


def summarize(x):
    """Return ``(mean, std / sqrt(len))`` of ``x`` using numpy's default ``std``."""
    return (np.mean(x), np.std(x)/np.sqrt(len(x)))


# Table preamble and header row.
# NOTE(review): the tabular spec declares eight columns while the header and
# every data row contain seven cells.
print("""
\\rowcolors{2}{gray!25}{white}
\\begin{table}[t]
\\centering
\\resizebox{\\textwidth}{!}{%
    {\\renewcommand{\\arraystretch}{1.5} 
    \\begin{tabular}{c c c c c c c c}
    \\rowcolor{gray!60}
    Sel. rule & Algo & $\\beta_{\\text{CS}}~\\nearrow$ & $\\beta_{\\text{size}}~\\searrow$ & $u_{65}~\\nearrow$ & $u_{80}~\\nearrow$ & $\\beta_{\\text{sup. infl.}}~\\searrow$ \\\\\\hline
""")

# One row per (selection rule, algorithm) pair.
for method in METHODS:
    for wt_type in ['base', 'subbagged']:
        print('\t', 
              formal_names[method], ' & ', # selection rule
              '$\\widetilde{\\mathcal{A}}_m$' if wt_type=='subbagged' else '$\\mathcal{A}$', ' & ', # algo 
               end=''
             )
        for metric in metrics:
            print('%.3f (%.3f)' % summarize(results[metric][(method, wt_type)]), ' & ', end='')
        if results['sup-infl'][(method,wt_type)] == 'NA':
            # Stay on the same output line, like the numeric branch below, so
            # the row terminator is printed on the row it belongs to.
            print(' - ', end='')
        else: 
            print('%.3f (%.3f)' % (results['sup-infl'][(method, wt_type)], results['sup-infl-se'][(method, wt_type)]), end='')
        print(' \\\\')

# Table footer and caption.
print("""
\\end{tabular}}
    }
    \\medskip
    \\caption{Results on the Fashion MNIST data set. The table displays the average precision, 
              $\\beta_{\\textnormal{CS}}$, and size, $\\beta_{\\textnormal{size}}$, as defined 
              in \\Cref{sec:experiments}, as well as utility-discounted predictive accuracies 
              $u_{65}$, $u_{80}$ and the superfluous inflation, all defined in this section. 
              For each metric, the symbol $\\nearrow$ indicates that higher values are desirable, 
              while $\\searrow$ indicates that lower values are desirable. Results for the base 
              algorithm are in white, and results for the subbagged algorithm are in gray. 
              Standard errors are shown in parentheses.}
    \\label{tab:results-extended}
\\end{table}""")

from tueplots import bundles

# Start from the tueplots NeurIPS 2024 style bundle, then override the font
# sizes and line width for this figure.
plt.rcParams.update(bundles.neurips2024())
plt.rcParams.update({'font.size' : 10,
                    'axes.titlesize': 10,
                    'lines.linewidth' : 2,})

# Re-assigned to the same value used earlier in the file. The figure title
# further down hard-codes "$B=1,000$" rather than reading this variable.
B = 1000

def plot_sf(arr, c, lab=None, ls='-'):
    """Draw a step curve of the fraction of values in ``arr`` at or above each level.

    The values are sorted in decreasing order and plotted on the x-axis
    against ``1/size, 2/size, ..., 1`` on the y-axis. An extra leading point
    at ``x = 1`` with height ``1/(2*size)`` is prepended so the curve starts
    at the right-hand edge with a positive height (the figure uses a
    logarithmic y-axis).

    An earlier definition of this function, without the leading point, was
    immediately overridden by this one and has been removed as dead code.

    Parameters
    ----------
    arr : sequence of values to summarise.
    c : colour passed to ``plt.step``.
    lab : legend label (optional).
    ls : line style.
    """
    size = len(arr)
    plt.step(np.hstack([1, np.sort(arr)[::-1]]), 
             np.hstack([1/(2*size), np.arange(1,1+size)/size]), 
             where='post', c=c, label=lab, ls=ls)
    
plt.figure(figsize=(4.5,2.781152949374527))

cs = [cred, lblue, nyllw, 'g', 'k']

# Same colour per selection rule; dashed for the base algorithm, solid for the
# subbagged one. Backslashes in the mathtext labels are written as '\\' so the
# literals contain no invalid escape sequences (the resulting strings are
# unchanged).
plot_sf(results['stability'][('argmax','base')], cs[0],  
        formal_names['argmax'] + '$\\circ \\mathcal{A}$', '--')
plot_sf(results['stability'][('inflate','base')], cs[1],  
        formal_names['inflate'] + '$\\circ \\mathcal{A}$', '--')
plot_sf(results['stability'][('argmax','subbagged')], cs[0], 
        formal_names['argmax'] + '$\\circ \\widetilde{\\mathcal{A}}_m$', '-')
plot_sf(results['stability'][('inflate','subbagged')], cs[1], 
        formal_names['inflate'] + '$\\circ \\widetilde{\\mathcal{A}}_m$', '-')

# Logarithmic y-axis.
plt.semilogy()
# plt.loglog()
plt.legend(loc='upper right')

# Hide the top and right frame lines.
plt.gca().spines['top'].set_color('none')
plt.gca().spines['right'].set_color('none')

plt.xticks([0, .5, 1])

plt.xlabel('$\\delta$', fontsize=12)
plt.ylabel('$\\frac{1}{N}\\sum_{j\\in [N]}1\\{\\delta_j > \\delta\\}$', 
           fontsize=12, labelpad=5)

# Not referenced anywhere below.
L = 10

# The lower y-limit sits below 1/N so the smallest plotted step stays visible.
plt.ylim([(2/3)/N,2])
plt.xlim([-.005,1.01])

plt.title('Stability comparison with $B=1,000$ bags')
plt.tight_layout()

plt.savefig('instability.pdf')
plt.show()